""" LangDT (Language Decision Transformer) Implementation """
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import numpy as np

from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim

from diffgro.environments.collect_dataset import get_skill_embed
from diffgro.langdt.planner import LangDTPlanner
from diffgro.utils import llm, print_r, print_y, print_b


class LangDT:
    """Rolling-context inference wrapper around a ``LangDTPlanner`` policy.

    The wrapper keeps fixed-length windows (``horizon`` steps, batch size 1)
    of past observations, actions, timesteps and validity masks, and feeds
    them to the planner policy together with a task embedding (and, for the
    ``'metaworld_complex'`` domain, the embedding of the current sub-task) to
    obtain the next action.

    ``reset()`` must be called before the first ``predict()`` of an episode:
    it is the only place where the context windows are created.
    """

    def __init__(
        self,
        env: gym.Env,
        planner: LangDTPlanner,
        verbose: bool = False,
    ):
        """Store the environment and the planner's policy, then derive sizes.

        Args:
            env: Environment the agent acts in. Must expose ``env_name`` and
                ``domain_name`` (and ``full_task_list`` / ``success_count``
                when ``domain_name == 'metaworld_complex'``).
            planner: Planner whose ``policy`` attribute is used for inference.
            verbose: Stored on the instance; not used by this class itself.
        """
        self.env = env
        # Only the policy object is needed at inference time.
        self.planner = planner.policy
        self.verbose = verbose

        self._setup()

    def _setup(self) -> None:
        """Derive dimensions, context length and language embeddings."""
        self.obs_dim = get_flattened_obs_dim(self.env.observation_space)
        self.act_dim = get_act_dim(self.env.action_space)
        self.horizon = self.planner.max_length

        # task embedding, reshaped to a single row
        self.task = get_skill_embed(None, self.env.env_name).reshape(1, -1)
        if self.env.domain_name == 'metaworld_complex':
            # one embedding per entry of the environment's task list
            self.skill = [
                get_skill_embed(None, task).reshape(1, -1)
                for task in self.env.full_task_list
            ]

    def reset(self) -> None:
        """Clear the step counters and zero out all context windows."""
        self.h, self.t = 0, 0
        self.obs_stack = np.zeros((1, self.horizon, self.obs_dim), dtype=np.float32)
        self.act_stack = np.zeros((1, self.horizon, self.act_dim), dtype=np.float32)
        self.t_stack = np.zeros((1, self.horizon), dtype=np.int32)
        # 0 marks a padded (not yet filled) slot, 1 a real step.
        self.mask_stack = np.zeros((1, self.horizon), dtype=np.int32)

    @staticmethod
    def _push(stack: np.ndarray, item: np.ndarray) -> np.ndarray:
        """Return ``stack`` with its oldest step dropped and ``item`` appended (axis 1)."""
        return np.concatenate([stack[:, 1:], item], axis=1)

    def predict(
        self, obs: np.ndarray, deterministic: bool = True
    ) -> Tuple[np.ndarray, None, Dict[str, Any]]:
        """Append ``obs`` to the context and return the planner's next action.

        Args:
            obs: Current observation; reshaped to ``(-1, 1, obs_dim)``.
            deterministic: Forwarded to the planner's ``_predict`` call.

        Returns:
            A tuple ``(action, None, {})`` where ``action`` is a 1-D array
            taken from the last step of the first plan returned by the planner.
        """
        # Observation and mask windows are advanced *before* inference so the
        # newest slot holds the current observation.
        self.obs_stack = self._push(
            self.obs_stack, obs.reshape((-1, 1, self.obs_dim))
        )
        self.mask_stack = self._push(
            self.mask_stack, np.ones((1, 1), dtype=np.int32)
        )

        task, skill = self.task, None
        if self.env.domain_name == 'metaworld_complex':
            # NOTE(review): success_count presumably can reach len(self.skill)
            # once every sub-task is finished; keep using the last skill in
            # that case instead of raising IndexError.
            skill_idx = min(self.env.success_count, len(self.skill) - 1)
            skill = self.skill[skill_idx]

        # planner inference
        _, plan = self.planner._predict(
            self.obs_stack,
            self.act_stack,
            self.t_stack,
            self.mask_stack,
            task,
            skill,
            deterministic=deterministic,
        )
        act = plan[0, -1].reshape(1, -1)

        # Action and timestep windows are advanced *after* inference, ready
        # for the next call.
        self.act_stack = self._push(
            self.act_stack, act.reshape((-1, 1, self.act_dim))
        )
        self.t_stack = self._push(
            self.t_stack, np.full((1, 1), self.t + 1, dtype=np.int32)
        )
        self.t += 1

        # np.array copies, so the returned action is detached from the stacks.
        return np.array(act[0]), None, {}
